import warnings
import pickle
from pathlib import Path
import yaml
import functools

import jax
import jax.numpy as jnp
from flax import serialization
import pandas as pd

import dreamerv3
from dreamerv3 import embodied
from dreamerv3.agent import ImagActorCritic
from dreamerv3.embodied.core.goal_sampler import GoalSampler, GoalSamplerCyclic
from baselines.qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from analysis.data_utils.analysis_repertoire import AnalysisRepertoire, AnalysisLatentRepertoire
from baselines.qdax import environments
from baselines.qdax.core.neuroevolution.networks.networks import MLPDC
from baselines.qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig
from baselines.qdax.core.neuroevolution.buffers.buffer import QDTransition
from baselines.qdax.tasks.brax_envs import reset_based_scoring_actor_dc_function_brax_envs as scoring_actor_dc_function
from baselines.qdax.environments import get_feat_mean
from baselines.qdax.utils.plotting import plot_2d_map_elites_repertoire

from omegaconf import OmegaConf
from utils import get_env
warnings.filterwarnings("ignore", ".*truncated to dtype int32.*")


# Benchmark task to evaluate; it selects the result directories below.
task = "humanoid"

# Result directories per task, always in the order
# (ours, smerl, smerl_reverse, uvfa, dcg_me).
_RESULT_DIRS_BY_TASK = {
    "walker2d": (
        Path("/project/output/results/ours/walker2d_feet_contact/2023-09-14_195857_633124"),
        Path("/project/output/results/smerl/walker2d_feet_contact/2023-09-17_234636_461181"),
        Path("/project/output/results/smerl_reverse/walker2d_feet_contact/2023-09-18_032249_538981"),
        Path("/project/output/results/uvfa/walker2d_feet_contact/2023-09-24_152622_443300"),
        Path("/project/output/results/dcg_me/walker2d_feet_contact/2023-09-14_170627_101109"),
    ),
    "humanoid": (
        Path("/project/output/results/ours/humanoid_feet_contact/"),
        Path("/project/output/results/smerl/humanoid_feet_contact/"),
        Path("/project/output/results/smerl_reverse/humanoid_feet_contact/"),
        Path("/project/output/results/uvfa/humanoid_feet_contact/"),
        Path("/project/output/results/dcg_me/humanoid_feet_contact/"),
    ),
}

# Any task without registered result directories is unsupported.
if task not in _RESULT_DIRS_BY_TASK:
    raise NotImplementedError

(
    ours_path,
    smerl_path,
    smerl_reverse_path,
    uvfa_path,
    dcg_me_path,
) = _RESULT_DIRS_BY_TASK[task]


def eval_ours(run_path, actuator_failure_idx, actuator_failure_strength):
    """Evaluate a DreamerV3-based ("ours") run under an actuator failure.

    The task/feat/backend settings are read from the run's wandb
    ``config.yaml``; the checkpoint ``run_path / "checkpoint.ckpt"`` is passed
    as ``--run.from_checkpoint`` and ``embodied.run.eval_only`` is executed
    while cycling through a grid of goals.

    Args:
        run_path: ``pathlib.Path`` of the run directory. It must contain a
            ``wandb`` sub-directory holding ``<run>/files/config.yaml``.
        actuator_failure_idx: forwarded unchanged to ``get_env``.
        actuator_failure_strength: forwarded unchanged to ``get_env``.

    Returns:
        The ``AnalysisRepertoire`` loaded from
        ``run_path / "results_dreamer.pkl"`` (presumably written by
        ``eval_only`` -- TODO confirm).

    Raises:
        FileNotFoundError: if no ``wandb/*/files/config.yaml`` can be found.
    """
    # ``iterdir`` yields entries in arbitrary order and the wandb directory may
    # also hold non-run entries (e.g. log files), so pick the first entry, in
    # sorted order, that really contains a run config.
    wandb_dir = run_path / "wandb"
    config_path = next(
        (
            candidate
            for candidate in (
                entry / "files" / "config.yaml"
                for entry in sorted(wandb_dir.iterdir())
            )
            if candidate.is_file()
        ),
        None,
    )
    if config_path is None:
        raise FileNotFoundError(f"No wandb config.yaml found under {wandb_dir}")
    with open(config_path) as f:
        wandb_config = yaml.safe_load(f)

    argv = [
        "--task={}".format(wandb_config["task"]["value"]),
        "--feat={}".format(wandb_config["feat"]["value"]),
        "--backend={}".format(wandb_config["backend"]["value"]),
        "--run.from_checkpoint={}".format(str(run_path / "checkpoint.ckpt")),
        "--envs.amount=2048",
    ]

    # Create config: defaults, then brax preset, then local overrides, then
    # the command-line style flags built above.
    logdir = str(run_path)
    config = embodied.Config(dreamerv3.configs["defaults"])
    config = config.update(dreamerv3.configs["brax"])
    config = config.update({
        "logdir": logdir,
        "run.train_ratio": 32,
        "run.log_every": 60,  # Seconds
        "batch_size": 16,
    })
    config = embodied.Flags(config).parse(argv=argv)

    # Create logger
    logdir = embodied.Path(config.logdir)
    step = embodied.Counter()
    logger = embodied.Logger(step, [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, "metrics.jsonl"),
        embodied.logger.TensorBoardOutput(logdir),
    ])

    # Create environment with the requested actuator failure applied.
    env = get_env(
        config,
        mode="train",
        actuator_failure_idx=actuator_failure_idx,
        actuator_failure_strength=actuator_failure_strength,
    )

    # Create agent and the run arguments consumed by ``eval_only``.
    agent = dreamerv3.Agent(env.obs_space, env.act_space, env.feat_space, step, config)
    args = embodied.Config(
        **config.run, logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length)

    # Create goal sampler: one goal per cell of a regular grid spanning the
    # feature space, each visited ``n_visits_per_goal`` times (module-level
    # setting defined further down in this script).
    resolution = ImagActorCritic.get_resolution(env.feat_space, config)
    grid_shape = (resolution,) * env.feat_space['vector'].shape[0]
    goals = compute_euclidean_centroids(
        grid_shape,
        minval=env.feat_space['vector'].low,
        maxval=env.feat_space['vector'].high,
    )
    goal_sampler_cyclic = GoalSamplerCyclic(
        feat_space=env.feat_space,
        goal_list=list(goals),
        number_visits_per_goal=n_visits_per_goal,
    )
    embodied.run.eval_only(
        agent,
        env,
        goal_sampler=goal_sampler_cyclic,
        period_sample_goals=float('inf'),
        logger=logger,
        args=args,
    )

    ours_repertoire = AnalysisRepertoire.create_from_path_collection_results(run_path / "results_dreamer.pkl")
    return ours_repertoire

def _eval_smerl_variant(run_path, actuator_failure_idx, actuator_failure_strength, *, reverse):
    """Shared implementation of ``eval_smerl`` and ``eval_smerl_reverse``.

    Rebuilds a ``DIAYNSMERL`` instance from the run's hydra config, restores
    the pickled actor and discriminator parameters, maps every goal of a
    50-cells-per-dimension grid to a latent goal with the discriminator and
    rolls the deterministic policy out on each latent goal
    ``n_visits_per_goal`` times (module-level setting defined further down in
    this script).

    Args:
        run_path: ``pathlib.Path`` of the run directory containing
            ``.hydra/config.yaml``, ``actor/actor.pickle`` and
            ``discriminator/discriminator.pickle``.
        actuator_failure_idx: forwarded to the environment wrapper kwargs.
        actuator_failure_strength: forwarded to the environment wrapper kwargs.
        reverse: value of the ``reverse`` field of ``DiaynSmerlConfig``.

    Returns:
        An ``AnalysisLatentRepertoire`` whose ``fitnesses`` and ``descriptors``
        stack the per-visit results along axis 1.
    """
    config = OmegaConf.load(run_path / ".hydra" / "config.yaml")

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment (the "failure" variant of the training task)
    env = environments.create(
        config.task + "failure" + "_" + config.feat,
        episode_length=config.algo.episode_length,
        backend=config.algo.backend,
        qdax_wrappers_kwargs=[{}, {"actuator_failure_idx": actuator_failure_idx,
                                   "actuator_failure_strength": actuator_failure_strength}],
    )

    # Define config
    smerl_config = DiaynSmerlConfig(
        # SAC config
        batch_size=config.algo.batch_size,
        episode_length=config.algo.episode_length,
        tau=config.algo.soft_tau_update,
        normalize_observations=config.algo.normalize_observations,
        learning_rate=config.algo.learning_rate,
        alpha_init=config.algo.alpha_init,
        discount=config.algo.discount,
        reward_scaling=config.algo.reward_scaling,
        hidden_layer_sizes=config.algo.hidden_layer_sizes,
        fix_alpha=config.algo.fix_alpha,
        # DIAYN config
        skill_type=config.algo.skill_type,
        num_skills=config.algo.num_skills,
        descriptor_full_state=config.algo.descriptor_full_state,
        extrinsic_reward=False,
        beta=1.,
        # SMERL
        reverse=reverse,
        diversity_reward_scale=config.algo.diversity_reward_scale,
        smerl_target=config.algo.smerl_target,
        smerl_margin=config.algo.smerl_margin,
    )

    # Define an instance of DIAYN
    smerl = DIAYNSMERL(config=smerl_config, action_size=env.action_size)

    # Freshly initialised parameters only serve as a structural template for
    # restoring the saved state dicts below.
    random_key, random_subkey_1, random_subkey_2 = jax.random.split(random_key, 3)
    fake_obs = jnp.zeros((env.observation_size + config.algo.num_skills,))
    fake_goal = jnp.zeros((config.algo.num_skills,))
    fake_actor_params = smerl._policy.init(random_subkey_1, fake_obs)
    fake_discriminator_params = smerl._discriminator.init(random_subkey_2, fake_goal)

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted run directories.
    with open(run_path / "actor/actor.pickle", "rb") as params_file:
        state_dict = pickle.load(params_file)
    actor_params = serialization.from_state_dict(fake_actor_params, state_dict)

    with open(run_path / "discriminator/discriminator.pickle", "rb") as params_file:
        state_dict = pickle.load(params_file)
    discriminator_params = serialization.from_state_dict(fake_discriminator_params, state_dict)

    # Create grid of goals and map each goal to a latent goal.
    resolution = 50
    grid_shape = (resolution,) * env.feat_space['vector'].shape[0]
    goals = compute_euclidean_centroids(
        grid_shape,
        minval=env.feat_space['vector'].low,
        maxval=env.feat_space['vector'].high,
    )
    latent_goals, _ = smerl._discriminator.apply(discriminator_params, goals)

    reset_fn = jax.jit(env.reset)

    @jax.jit
    def play_step_fn(env_state, params, latent_goal, random_key):
        # The policy is conditioned on the observation concatenated with the
        # latent goal.
        actions, random_key = smerl.select_action(
            obs=jnp.concatenate([env_state.obs, latent_goal], axis=0),
            policy_params=params,
            random_key=random_key,
            deterministic=True,
        )
        state_desc = env_state.info["state_descriptor"]
        next_state = env.step(env_state, actions)

        transition = QDTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["state_descriptor"],
            # Descriptor fields are filled with NaN placeholders here.
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
        )

        return next_state, params, latent_goal, random_key, transition

    # Prepare the scoring function
    scoring_fn = jax.jit(functools.partial(
        scoring_actor_dc_function,
        episode_length=config.algo.episode_length,
        play_reset_fn=reset_fn,
        play_step_actor_dc_fn=play_step_fn,
        behavior_descriptor_extractor=get_feat_mean,
    ))

    @jax.jit
    def evaluate_actor(random_key, params, latent_goals):
        # Replicate the single set of actor parameters once per latent goal.
        params = jax.tree_util.tree_map(
            lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), latent_goals.shape[0], axis=0),
            params,
        )
        fitnesses, descriptors, extra_scores, random_key = scoring_fn(
            params, latent_goals, random_key
        )
        return fitnesses, descriptors, extra_scores, random_key

    fitnesses_list = []
    descriptor_list = []
    for _ in range(n_visits_per_goal):
        fitnesses, descriptors, _extra_scores, random_key = evaluate_actor(
            random_key, actor_params, latent_goals
        )
        fitnesses_list.append(fitnesses)
        descriptor_list.append(descriptors)

    return AnalysisLatentRepertoire(
        centroids=goals,
        latent_goals=latent_goals,
        fitnesses=jnp.stack(fitnesses_list, axis=1),
        descriptors=jnp.stack(descriptor_list, axis=1))


def eval_smerl(run_path, actuator_failure_idx, actuator_failure_strength):
    """Evaluate a SMERL run (``reverse=False``) under an actuator failure.

    Args:
        run_path: ``pathlib.Path`` of the SMERL run directory.
        actuator_failure_idx: forwarded to the environment wrapper kwargs.
        actuator_failure_strength: forwarded to the environment wrapper kwargs.

    Returns:
        The resulting ``AnalysisLatentRepertoire``.
    """
    return _eval_smerl_variant(
        run_path, actuator_failure_idx, actuator_failure_strength, reverse=False
    )


def eval_smerl_reverse(run_path, actuator_failure_idx, actuator_failure_strength):
    """Evaluate a reverse-SMERL run (``reverse=True``) under an actuator failure.

    Args:
        run_path: ``pathlib.Path`` of the reverse-SMERL run directory.
        actuator_failure_idx: forwarded to the environment wrapper kwargs.
        actuator_failure_strength: forwarded to the environment wrapper kwargs.

    Returns:
        The resulting ``AnalysisLatentRepertoire``.
    """
    return _eval_smerl_variant(
        run_path, actuator_failure_idx, actuator_failure_strength, reverse=True
    )

def eval_uvfa(run_path, actuator_failure_idx, actuator_failure_strength):
    """Evaluate a UVFA run under an actuator failure.

    Reads task/feat/backend from ``wandb/latest-run/files/config.yaml``,
    builds the DreamerV3 config, logger, environment and agent, then calls
    ``embodied.run.eval_only`` while cycling through a grid of goals.

    Args:
        run_path: ``pathlib.Path`` of the UVFA run directory.
        actuator_failure_idx: forwarded unchanged to ``get_env``.
        actuator_failure_strength: forwarded unchanged to ``get_env``.

    Returns:
        The ``AnalysisRepertoire`` loaded from
        ``run_path / "results_dreamer.pkl"``.
    """
    # Settings recorded by wandb for the training run.
    wandb_config_file = run_path / "wandb" / "latest-run" / "files" / "config.yaml"
    with open(wandb_config_file) as handle:
        wandb_config = yaml.safe_load(handle)

    checkpoint_file = str(run_path / "checkpoint.ckpt")
    cli_flags = [
        f"--task={wandb_config['task']['value']}",
        f"--feat={wandb_config['feat']['value']}",
        f"--backend={wandb_config['backend']['value']}",
        f"--run.from_checkpoint={checkpoint_file}",
        "--envs.amount=2048",
    ]

    # Defaults -> brax preset -> local overrides -> flags.
    overrides = {
        "logdir": str(run_path),
        "run.train_ratio": 32,
        "run.log_every": 60,  # Seconds
        "batch_size": 16,
    }
    config = (
        embodied.Config(dreamerv3.configs["defaults"])
        .update(dreamerv3.configs["brax"])
        .update(overrides)
    )
    config = embodied.Flags(config).parse(argv=cli_flags)

    # Logging goes to the terminal, a JSONL file and TensorBoard.
    log_path = embodied.Path(config.logdir)
    step = embodied.Counter()
    log_outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(log_path, "metrics.jsonl"),
        embodied.logger.TensorBoardOutput(log_path),
    ]
    logger = embodied.Logger(step, log_outputs)

    # Environment with the requested actuator failure.
    env = get_env(
        config,
        mode="train",
        actuator_failure_idx=actuator_failure_idx,
        actuator_failure_strength=actuator_failure_strength,
    )

    # Agent plus the run arguments handed to ``eval_only``.
    agent = dreamerv3.Agent(env.obs_space, env.act_space, env.feat_space, step, config)
    run_args = embodied.Config(
        **config.run,
        logdir=config.logdir,
        batch_steps=config.batch_size * config.batch_length,
    )

    # Regular grid of goals over the feature space, visited cyclically.
    cells_per_dim = ImagActorCritic.get_resolution(env.feat_space, config)
    goal_grid_shape = (cells_per_dim,) * env.feat_space['vector'].shape[0]
    goal_grid = compute_euclidean_centroids(
        goal_grid_shape,
        minval=env.feat_space['vector'].low,
        maxval=env.feat_space['vector'].high,
    )
    cyclic_sampler = GoalSamplerCyclic(
        feat_space=env.feat_space,
        goal_list=list(goal_grid),
        number_visits_per_goal=n_visits_per_goal,
    )
    embodied.run.eval_only(
        agent,
        env,
        goal_sampler=cyclic_sampler,
        period_sample_goals=float('inf'),
        logger=logger,
        args=run_args,
    )

    return AnalysisRepertoire.create_from_path_collection_results(run_path / "results_dreamer.pkl")

def eval_dcg_me(run_path, actuator_failure_idx, actuator_failure_strength):
    """Evaluate a DCG-ME descriptor-conditioned actor under an actuator failure.

    Rebuilds the ``MLPDC`` actor from the run's hydra config, restores its
    pickled parameters and rolls it out on every goal of a
    50-cells-per-dimension grid, ``n_visits_per_goal`` times (module-level
    setting defined further down in this script).

    Args:
        run_path: ``pathlib.Path`` of the run directory containing
            ``.hydra/config.yaml`` and ``actor/actor.pickle``.
        actuator_failure_idx: forwarded to the environment wrapper kwargs.
        actuator_failure_strength: forwarded to the environment wrapper kwargs.

    Returns:
        An ``AnalysisRepertoire`` whose ``fitnesses`` and ``descriptors`` stack
        the per-visit results along axis 1.
    """
    config = OmegaConf.load(run_path / ".hydra" / "config.yaml")

    # Init a random key
    random_key = jax.random.PRNGKey(config.seed)

    # Init environment (the "failure" variant of the training task)
    env = environments.create(
        config.task + "failure" + "_" + config.feat,
        episode_length=config.algo.episode_length,
        backend=config.algo.backend,
        qdax_wrappers_kwargs=[{}, {"actuator_failure_idx": actuator_failure_idx,
                                   "actuator_failure_strength": actuator_failure_strength}],
    )
    reset_fn = jax.jit(env.reset)

    # Init policy network
    policy_layer_sizes = config.algo.policy_hidden_layer_sizes + (env.action_size,)
    actor_dc_network = MLPDC(
        layer_sizes=policy_layer_sizes,
        kernel_init=jax.nn.initializers.lecun_uniform(),
        final_activation=jnp.tanh,
    )

    # Freshly initialised parameters only serve as a structural template for
    # restoring the saved state dict below.
    random_key, subkey = jax.random.split(random_key)
    fake_obs = jnp.zeros(shape=(env.observation_size,))
    fake_desc = jnp.zeros(shape=(env.behavior_descriptor_length,))
    fake_actor_params = actor_dc_network.init(subkey, fake_obs, fake_desc)

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted run directories.
    with open(run_path / "actor/actor.pickle", "rb") as params_file:
        state_dict = pickle.load(params_file)
    actor_params = serialization.from_state_dict(fake_actor_params, state_dict)

    # Create grid
    resolution = 50
    grid_shape = (resolution,) * env.feat_space['vector'].shape[0]
    goals = compute_euclidean_centroids(
        grid_shape,
        minval=env.feat_space['vector'].low,
        maxval=env.feat_space['vector'].high,
    )

    def play_step_actor_dc_fn(env_state, actor_dc_params, desc, random_key):
        # The actor is fed the descriptor divided by the first entry of the
        # upper descriptor limit; compute it once and reuse it below.
        normalized_desc = desc / env.behavior_descriptor_limits[1][0]
        actions = actor_dc_network.apply(actor_dc_params, env_state.obs, normalized_desc)
        state_desc = env_state.info["feat"]
        next_state = env.step(env_state, actions)

        transition = QDTransition(
            obs=env_state.obs,
            next_obs=next_state.obs,
            rewards=next_state.reward,
            dones=next_state.done,
            truncations=next_state.info["truncation"],
            actions=actions,
            state_desc=state_desc,
            next_state_desc=next_state.info["feat"],
            # ``desc`` is filled with a NaN placeholder here.
            desc=jnp.zeros(env.behavior_descriptor_length,) * jnp.nan,
            desc_prime=normalized_desc,
        )

        return next_state, actor_dc_params, desc, random_key, transition

    # Prepare the scoring function
    scoring_fn = jax.jit(functools.partial(
        scoring_actor_dc_function,
        episode_length=env.episode_length,
        play_reset_fn=reset_fn,
        play_step_actor_dc_fn=play_step_actor_dc_fn,
        behavior_descriptor_extractor=get_feat_mean,
    ))

    @jax.jit
    def evaluate_actor(random_key, params, goals):
        # Replicate the single set of actor parameters once per goal.
        params = jax.tree_util.tree_map(
            lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), goals.shape[0], axis=0),
            params,
        )
        fitnesses, descriptors, extra_scores, random_key = scoring_fn(
            params, goals, random_key
        )
        return fitnesses, descriptors, extra_scores, random_key

    fitnesses_list = []
    descriptor_list = []
    for _ in range(n_visits_per_goal):
        fitnesses, descriptors, _extra_scores, random_key = evaluate_actor(
            random_key, actor_params, goals
        )
        fitnesses_list.append(fitnesses)
        descriptor_list.append(descriptors)

    dcg_me_repertoire = AnalysisRepertoire(
        centroids=goals,
        fitnesses=jnp.stack(fitnesses_list, axis=1),
        descriptors=jnp.stack(descriptor_list, axis=1))
    return dcg_me_repertoire


# Sweep settings.
n_seeds = 10  # number of seed indices iterated by the evaluation loop
n_actuator_failure = 20  # number of failure strengths sampled in [0, 1]
n_visits_per_goal = 1  # evaluation repetitions per goal

# Index (or indices) of the actuator(s) that fail; the first entry also names
# the output CSV.
actuator_failure_idx = [10]

# Actuator index -> name for the humanoid. The indices must be unique and
# contiguous: a duplicated key would silently drop a name and shift every
# later entry by one.
# NOTE(review): presumably follows the actuator order of the humanoid model
# definition -- confirm against the environment's XML.
actuator_name_humanoid = {
    0: "abdomen_y",
    1: "abdomen_z",
    2: "abdomen_x",
    3: "right_hip_x",
    4: "right_hip_z",
    5: "right_hip_y",
    6: "right_knee",
    7: "left_hip_x",
    8: "left_hip_z",
    9: "left_hip_y",
    10: "left_knee",
    11: "right_shoulder1",
    12: "right_shoulder2",
    13: "right_elbow",
    14: "left_shoulder1",
    15: "left_shoulder2",
    16: "left_elbow",
}

# Actuator index -> name for walker2d.
actuator_name_walker2d = {
    0: "thigh_joint",
    1: "leg_joint",
    2: "foot_joint",
    3: "thigh_left_joint",
    4: "leg_left_joint",
    5: "foot_left_joint",
}

# Resolve the name of the first failing actuator for the selected task; the
# name is used in the output CSV filename.
_ACTUATOR_NAMES_BY_TASK = {
    "humanoid": actuator_name_humanoid,
    "walker2d": actuator_name_walker2d,
}
if task not in _ACTUATOR_NAMES_BY_TASK:
    raise NotImplementedError
actuator_name = _ACTUATOR_NAMES_BY_TASK[task][actuator_failure_idx[0]]

# One entry per seed run directory for every algorithm.
# NOTE(review): ``iterdir`` order is arbitrary, so the mapping from seed index
# to directory is only as stable as the filesystem listing -- confirm before
# relying on the index-based resume below.
ours_seeds = list(ours_path.iterdir())
smerl_seeds = list(smerl_path.iterdir())
smerl_reverse_seeds = list(smerl_reverse_path.iterdir())
uvfa_seeds = list(uvfa_path.iterdir())
dcg_me_seeds = list(dcg_me_path.iterdir())


def _summarize_repertoire(repertoire):
    """Reduce a repertoire to the two scalars stored in the results table.

    Returns:
        ``(fitness, distance_to_goal)`` where ``fitness`` is the maximum over
        the median of ``repertoire.fitnesses`` along its last axis, and
        ``distance_to_goal`` is the mean negated Euclidean distance between
        each centroid and the axis-1 mean of ``repertoire.descriptors``.
    """
    fitness = jnp.max(jnp.median(repertoire.fitnesses, axis=-1))
    mean_descriptors = jnp.mean(repertoire.descriptors, axis=1)
    distance_to_goal = jnp.mean(
        -jnp.linalg.norm(repertoire.centroids - mean_descriptors, axis=-1)
    )
    return fitness, distance_to_goal


# (label, evaluation function, seed directories), in evaluation order.
_EVALUATIONS = (
    ("ours", eval_ours, ours_seeds),
    ("smerl", eval_smerl, smerl_seeds),
    ("smerl_reverse", eval_smerl_reverse, smerl_reverse_seeds),
    ("uvfa", eval_uvfa, uvfa_seeds),
    ("dcg_me", eval_dcg_me, dcg_me_seeds),
)

# Resume point: every (seed, strength index) pair before this one is skipped.
# NOTE: ``df`` starts empty and the CSV is rewritten after every strength, so
# a resumed run replaces the file with only the rows computed in this run.
_RESUME_SEED = 9
_RESUME_STRENGTH_IDX = 9

df = pd.DataFrame(columns=["algo", "seed", "actuator_failure_strength", "fitness", "distance_to_goal"])
for i in range(n_seeds):
    if i < _RESUME_SEED:
        continue
    for j, actuator_failure_strength in enumerate(jnp.linspace(0, 1, n_actuator_failure)):
        if i == _RESUME_SEED and j < _RESUME_STRENGTH_IDX:
            continue
        print(f"seed: {i}, actuator_failure_strength: {j}/{n_actuator_failure}")

        for algo, eval_fn, seed_dirs in _EVALUATIONS:
            print(f"{algo}: {actuator_failure_strength}")
            repertoire = eval_fn(seed_dirs[i], actuator_failure_idx, [actuator_failure_strength])
            fitness, distance_to_goal = _summarize_repertoire(repertoire)
            df.loc[len(df)] = [algo, i, actuator_failure_strength, fitness, distance_to_goal]

        # Checkpoint the table after each failure strength.
        df.to_csv(f"/project/output/hierarchy/{task}_{actuator_name}_{n_seeds}_{n_actuator_failure}_{n_visits_per_goal}.csv")
